import inspect
import os.path
import time
from functools import wraps

import torch


class Profiler:
    """Process-wide singleton that stores raw ``(start_time, end_time)`` records.

    Every ``Profiler()`` call returns the same instance, so all timing helpers
    share one record store. ``enabled`` is only a flag: ``record`` itself always
    stores, and it is up to callers to check ``enabled`` before timing.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Create the shared instance lazily on first use; later calls reuse it
        # (no __init__ is defined, so existing state is never reset here).
        if cls._instance is None:
            cls._instance = super(Profiler, cls).__new__(cls)
            cls._instance._records = {}  # key -> list of (start, end) tuples
            cls._instance.enabled = False
        return cls._instance

    def disable(self):
        """Set the ``enabled`` flag to False."""
        self.enabled = False

    def enable(self):
        """Set the ``enabled`` flag to True."""
        self.enabled = True

    def record(self, key, start_time, end_time):
        """Append one raw timing record for ``key``.

        Args:
            key (str): Identifier for the operation being timed.
            start_time (float): Operation start time.
            end_time (float): Operation end time.
        """
        self._records.setdefault(key, []).append((start_time, end_time))

    def get_raw_metrics(self):
        """Return the raw timing records.

        Returns:
            dict: The live internal mapping of keys to lists of
            ``(start_time, end_time)`` tuples (not a copy).
        """
        return self._records

    def get_metric_summary(self, skip_first_n=0):
        """Compute summary statistics for each key, optionally ignoring initial records.

        Args:
            skip_first_n (int): Number of initial records to ignore for each key
                (useful for excluding warmup iterations). Negative values are
                treated as 0.

        Returns:
            dict: For each key that has at least one record left after skipping,
            a dict with ``call_count``, ``mean_time``, ``std_time`` (sample
            standard deviation, 0.0 for a single record), ``last_start_time``
            and ``last_end_time``.
        """
        # A negative slice start would count from the end of the list and keep
        # only the *last* records, which is the opposite of "skip none".
        skip = max(skip_first_n, 0)
        summary = {}

        for key, records in self._records.items():
            valid_records = records[skip:]
            if not valid_records:
                continue

            exec_times = [end - start for start, end in valid_records]
            count = len(exec_times)
            mean_time = sum(exec_times) / count

            if count > 1:
                # Sample (Bessel-corrected) variance.
                variance = sum((t - mean_time) ** 2 for t in exec_times) / (count - 1)
                std_time = variance ** 0.5
            else:
                std_time = 0.0

            last_start, last_end = valid_records[-1]
            summary[key] = {
                "call_count": count,
                "mean_time": mean_time,
                "std_time": std_time,
                "last_start_time": last_start,
                "last_end_time": last_end,
            }

        return summary

    def reset(self):
        """Discard all stored records."""
        self._records = {}


class ProfileContext:
    """Context manager that times the enclosed block and records it in ``Profiler``.

    Timing only happens if the shared profiler is enabled when the block is
    entered and still enabled when it exits. Exceptions raised inside the block
    are never suppressed.

    Args:
        name: Explicit record key. When omitted, a key is derived from the
            caller's class (via a local named ``self``) or file name, plus the
            function name and line number of the ``with`` statement.
        use_cuda_sync: If true and CUDA is available, call
            ``torch.cuda.synchronize()`` before each timestamp. This keeps
            earlier GPU work out of the interval and makes sure GPU work
            launched inside the block is counted.
    """

    def __init__(self, name=None, use_cuda_sync=True):
        self.name = name
        self.use_cuda_sync = use_cuda_sync
        self.profiler = Profiler()
        self.start_time = None
        self.key = None

    def _build_key(self, frame):
        """Return the record key for a ``with`` statement running in ``frame``."""
        if self.name:
            return self.name
        if frame is None:
            # inspect.currentframe() may return None on non-CPython interpreters.
            return "<unknown>"
        code = frame.f_code
        filename = os.path.basename(code.co_filename)
        line_number = frame.f_lineno
        method_name = code.co_name
        caller_locals = frame.f_locals
        class_name = caller_locals["self"].__class__.__name__ if "self" in caller_locals else None

        if class_name:
            return f"{class_name}:{method_name}:{line_number}"
        if method_name:
            return f"{filename}:{method_name}:{line_number}"
        return f"{filename}:{line_number}"

    def __enter__(self):
        if not self.profiler.enabled:
            self.start_time = None
            return self

        # Resolve the key BEFORE the sync/timestamp so the cost of frame
        # inspection is not included in the measured interval.
        frame = inspect.currentframe()
        try:
            self.key = self._build_key(frame.f_back if frame is not None else None)
        finally:
            del frame  # drop the frame reference promptly

        if self.use_cuda_sync and torch.cuda.is_available():
            torch.cuda.synchronize()  # exclude previously queued GPU work
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Take and clear the start time so a reused instance never records a
        # stale value. It is None when the profiler was disabled at __enter__,
        # in which case no key was built and there is nothing to record.
        start_time, self.start_time = self.start_time, None
        if start_time is None or not self.profiler.enabled:
            return
        if self.use_cuda_sync and torch.cuda.is_available():
            torch.cuda.synchronize()  # include GPU work launched in the block
        end_time = time.perf_counter()
        self.profiler.record(self.key, start_time, end_time)


def profile_method(func=None, *, use_cuda_sync=True):
    """Decorate an instance method so each call is timed and recorded in ``Profiler``.

    Usable bare (``@profile_method``) or with options
    (``@profile_method(use_cuda_sync=False)``). A call is timed only while the
    shared profiler is enabled; otherwise the method is simply invoked. If the
    method raises, nothing is recorded and the exception propagates.

    The record key is ``"<caller file path>:<class>:<method>:<caller line>"``.
    The file path and line number come from the *call site*, so calls made from
    different places are recorded under different keys.

    Args:
        func: The method being decorated (supplied when used without parentheses).
        use_cuda_sync: If true and CUDA is available, synchronize the GPU before
            both timestamps. This keeps earlier GPU work out of the interval and
            makes sure work launched by this call is counted.

    Returns:
        The wrapped method, or a decorator when ``func`` is None.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            profiler = Profiler()
            if not profiler.enabled:
                return func(self, *args, **kwargs)

            # Build the key before timing so frame inspection is not measured.
            frame = inspect.currentframe()
            caller = frame.f_back if frame is not None else None
            if caller is not None:
                filename, line_number = caller.f_code.co_filename, caller.f_lineno
            else:
                # inspect.currentframe() may return None outside CPython.
                filename, line_number = "<unknown>", 0
            del frame, caller  # drop frame references promptly
            key = f"{filename}:{self.__class__.__name__}:{func.__name__}:{line_number}"

            sync = use_cuda_sync and torch.cuda.is_available()
            if sync:
                torch.cuda.synchronize()  # exclude GPU work queued before this call
            start_time = time.perf_counter()
            result = func(self, *args, **kwargs)
            if sync:
                torch.cuda.synchronize()  # include GPU work launched by this call
            end_time = time.perf_counter()

            profiler.record(key, start_time, end_time)
            return result

        return wrapper

    if func is None:
        return decorator
    return decorator(func)


class CustomProfiler:
    """Manually started/stopped timer that records intervals under a fixed key.

    Intervals are stored in the shared ``Profiler`` and are only captured while
    that profiler is enabled.

    Args:
        name: Key under which each start/stop interval is recorded.
        use_cuda_sync: If true and CUDA is available, synchronize the GPU at
            both ``start`` and ``stop`` before reading the clock.
    """

    def __init__(self, name, use_cuda_sync=True):
        self.name = name
        self.use_cuda_sync = use_cuda_sync
        self.profiler = Profiler()
        self.start_time = None

    def _maybe_sync(self):
        """Wait for queued GPU work when syncing is requested and CUDA exists."""
        if self.use_cuda_sync and torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self):
        """Begin an interval; does nothing while the profiler is disabled."""
        if self.profiler.enabled:
            self._maybe_sync()
            self.start_time = time.perf_counter()

    def stop(self):
        """Finish the current interval and record it.

        Ignored while the profiler is disabled or when no interval was started.
        """
        if not self.profiler.enabled or self.start_time is None:
            return
        self._maybe_sync()
        finished_at = time.perf_counter()
        self.profiler.record(self.name, self.start_time, finished_at)
        self.start_time = None

    def reset(self):
        """Drop any in-progress interval without recording it."""
        self.start_time = None
